package com.flow.framework.cipher.algorithm;

import com.flow.framework.cipher.algorithm.version.RsaVersion;
import com.flow.framework.cipher.pojo.bo.CipherKeyPairBo;
import com.flow.framework.common.error.SystemErrorCode;
import com.flow.framework.common.exception.CheckedException;
import com.flow.framework.common.stream.handler.BatchVoidProcessHandler;
import com.flow.framework.common.util.io.IoUtil;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.codec.binary.Base64;

import javax.crypto.Cipher;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;

/**
 * RSA封装
 *
 * @author luoguopiao
 * @version 0.0.1
 * @date 2022/1/16
 */
@Slf4j
class Rsa {

    /**
     * 加密名称RSA
     */
    private static final String RSA = "RSA";

    /**
     * 创建公私钥
     *
     * @return 公私钥对象
     */
    static CipherKeyPairBo generateKey(RsaVersion version) {
        try {
            KeyPairGenerator generator = KeyPairGenerator.getInstance(RSA);

            //初始化KeyPairGenerator对象,密钥长度
            generator.initialize(version.getKeySize());

            //生成公私钥对
            KeyPair keyPair = generator.generateKeyPair();
            PrivateKey privateKey = keyPair.getPrivate();
            PublicKey publicKey = keyPair.getPublic();
            return new CipherKeyPairBo(
                    Base64.encodeBase64String(publicKey.getEncoded()), Base64.encodeBase64String(privateKey.getEncoded())
            );
        } catch (NoSuchAlgorithmException e) {
            log.error("generate key failed.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "generate key failed.", e);
        }
    }

    /**
     * 加密字符串
     *
     * @param version      version
     * @param content      content
     * @param publicKeyStr publicKeyStr
     * @return 密文
     */
    static String encrypt(RsaVersion version, String content, String publicKeyStr) {
        if (null == content || null == publicKeyStr) {
            log.error("encrypt error. params can't be empty.");
            throw new CheckedException(SystemErrorCode.PARAMS_ERROR);
        }
        byte[] contentByte = content.getBytes(StandardCharsets.UTF_8);
        byte[] encryptByte = encrypt(version, contentByte, publicKeyStr);
        return Base64.encodeBase64String(encryptByte);
    }

    /**
     * 加密byte数组
     *
     * @param version      version
     * @param contentBytes contentBytes
     * @param publicKeyStr publicKeyStr
     * @return 密文
     */
    static byte[] encrypt(RsaVersion version, byte[] contentBytes, String publicKeyStr) {
        if (null == contentBytes || null == publicKeyStr) {
            log.error("encrypt error. params can't be empty.");
            throw new CheckedException(SystemErrorCode.PARAMS_ERROR);
        }
        try {
            //转换公钥
            X509EncodedKeySpec keySpec = new X509EncodedKeySpec(Base64.decodeBase64(publicKeyStr));
            KeyFactory keyFactory = KeyFactory.getInstance(RSA);
            PublicKey publicKey = keyFactory.generatePublic(keySpec);

            //对数据加密
            Cipher cipher = Cipher.getInstance(version.getCipherAlgorithm());
            cipher.init(Cipher.ENCRYPT_MODE, publicKey);
            return AlgorithmUtil.cipherDoFinal(contentBytes, cipher, version.getEncryptBlockSize());
        } catch (Exception e) {
            log.error("encrypt error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "encrypt error.", e);
        }
    }


    /**
     * 解密字符串
     *
     * @param version       version
     * @param content       content
     * @param privateKeyStr privateKeyStr
     * @return 原文
     */
    static String decrypt(RsaVersion version, String content, String privateKeyStr) {
        if (null == content || null == privateKeyStr) {
            log.error("decrypt error. params can't be empty.");
            throw new CheckedException(SystemErrorCode.PARAMS_ERROR);
        }
        byte[] contentByte = Base64.decodeBase64(content);
        byte[] decryptedData = decrypt(version, contentByte, privateKeyStr);
        return new String(decryptedData, StandardCharsets.UTF_8);
    }

    /**
     * 解密byte数组
     *
     * @param version       version
     * @param contentBytes  contentBytes
     * @param privateKeyStr privateKeyStr
     * @return 原文
     */
    static byte[] decrypt(RsaVersion version, byte[] contentBytes, String privateKeyStr) {
        if (null == contentBytes || null == privateKeyStr) {
            log.error("decrypt error. params can't be empty.");
            throw new CheckedException(SystemErrorCode.PARAMS_ERROR);
        }
        try {
            PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(Base64.decodeBase64(privateKeyStr));
            KeyFactory keyFactory = KeyFactory.getInstance(RSA);
            PrivateKey privateKey = keyFactory.generatePrivate(keySpec);
            Cipher cipher = Cipher.getInstance(version.getCipherAlgorithm());
            cipher.init(Cipher.DECRYPT_MODE, privateKey);
            return AlgorithmUtil.cipherDoFinal(contentBytes, cipher, version.getDecryptBlockSize());
        } catch (Exception e) {
            log.error("decrypt error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "decrypt error.", e);
        }
    }

    /**
     * 使用私钥对数据进行签名
     *
     * @param version    version
     * @param content    content
     * @param privateKey privateKey
     * @return 签名
     */
    static String sign(RsaVersion version, String content, String privateKey) {
        try {
            PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(Base64.decodeBase64(privateKey));
            KeyFactory keyFactory = KeyFactory.getInstance(RSA);
            PrivateKey privateK = keyFactory.generatePrivate(keySpec);
            Signature signature = Signature.getInstance(version.getSignatureAlgorithm());
            signature.initSign(privateK);
            byte[] data = content.getBytes(StandardCharsets.UTF_8);
            signature.update(data);
            return Base64.encodeBase64String(signature.sign());
        } catch (Exception e) {
            log.error("sign error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "sign error.", e);
        }
    }

    /**
     * 使用私钥对数据进行签名
     *
     * @param version     version
     * @param inputStream inputStream
     * @param privateKey  privateKey
     * @return 签名
     */
    static String sign(RsaVersion version, InputStream inputStream, String privateKey) {
        try {
            PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(Base64.decodeBase64(privateKey));
            KeyFactory keyFactory = KeyFactory.getInstance(RSA);
            PrivateKey privateK = keyFactory.generatePrivate(keySpec);
            Signature signature = Signature.getInstance(version.getSignatureAlgorithm());
            signature.initSign(privateK);
            update(inputStream, signature);
            return Base64.encodeBase64String(signature.sign());
        } catch (Exception e) {
            log.error("sign error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "sign error.", e);
        }
    }

    private static void update(InputStream inputStream, Signature signature) {
        IoUtil.handleInputStream(inputStream, new BatchVoidProcessHandler() {
            @Override
            public int getBatchSize() {
                return 4196;
            }

            @Override
            public void process(byte[] buffer) {
                try {
                    signature.update(buffer);
                } catch (Exception e) {
                    log.error("signature update error.", e);
                    throw new CheckedException(SystemErrorCode.UNEXPECTED_ERROR, "signature update error", e);
                }
            }
        }, exception -> {
            log.error("sign error.", exception);
            throw new CheckedException(SystemErrorCode.UNEXPECTED_ERROR, "sign error.", exception);
        });
    }

    /**
     * 使用公钥对私钥的签名数据进行验证
     *
     * @param version   version
     * @param content   content
     * @param publicKey publicKey
     * @param sign      私钥加签结果
     * @return 是否验签成功
     */
    static boolean verify(RsaVersion version, String content, String publicKey, String sign) {
        try {
            X509EncodedKeySpec keySpec = new X509EncodedKeySpec(Base64.decodeBase64(publicKey));
            KeyFactory keyFactory = KeyFactory.getInstance(RSA);
            PublicKey publicK = keyFactory.generatePublic(keySpec);
            Signature signature = Signature.getInstance(version.getSignatureAlgorithm());
            signature.initVerify(publicK);
            byte[] data = content.getBytes(StandardCharsets.UTF_8);
            signature.update(data);
            return signature.verify(Base64.decodeBase64(sign));
        } catch (Exception e) {
            log.error("verify sign error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "verify sign error.", e);
        }
    }


    /**
     * 使用公钥对私钥的签名数据进行验证
     *
     * @param version     version
     * @param inputStream inputStream
     * @param publicKey   publicKey
     * @param sign        私钥加签结果
     * @return 是否验签成功
     */
    static boolean verify(RsaVersion version, InputStream inputStream, String publicKey, String sign) {
        try {
            X509EncodedKeySpec keySpec = new X509EncodedKeySpec(Base64.decodeBase64(publicKey));
            KeyFactory keyFactory = KeyFactory.getInstance(RSA);
            PublicKey publicK = keyFactory.generatePublic(keySpec);
            Signature signature = Signature.getInstance(version.getSignatureAlgorithm());
            signature.initVerify(publicK);
            update(inputStream, signature);
            return signature.verify(Base64.decodeBase64(sign));
        } catch (Exception e) {
            log.error("verify sign error.", e);
            throw new CheckedException(SystemErrorCode.CIPHER_ERROR, "verify sign error.", e);
        }
    }
}
